#!/usr/bin/python
# -*- coding: utf-8 -*-
import re
import os
import shutil
import copy
import datetime
import numpy as np

from rouge import Rouge

from .logger import *

# from data import *

import sys

# Raise the interpreter's recursion limit to 10000.
# NOTE(review): presumably needed because ROUGE scoring of long texts recurses
# deeply -- confirm before lowering or removing.
sys.setrecursionlimit(10000)

# Lower-cased bracket escape tokens and ``/'' quote markers, mapped to the
# literal characters they stand for. Looked up by clean().
REMAP = {
    "-lrb-": "(",
    "-rrb-": ")",
    "-lcb-": "{",
    "-rcb-": "}",
    "-lsb-": "[",
    "-rsb-": "]",
    "``": '"',
    "''": '"',
}

############## Change this to your own ROUGE path ##############
# Prefix that the pyrouge_* functions prepend to "RELEASE-1.5.5" to locate the
# ROUGE distribution and its data directory.
ROUGE_PATH = "/ROUGE/"


def clean(x):
    """Return ``x`` with escaped bracket/quote tokens replaced by literals.

    Each occurrence of one of the tokens listed in the pattern below is
    substituted by its entry in the module-level ``REMAP`` table.
    """

    def _literal(found):
        # Translate the matched token through the shared lookup table.
        return REMAP.get(found.group())

    token_pattern = r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''"
    return re.sub(token_pattern, _literal, x)


def rouge_eval(hyps, refer):
    """Return the mean of the ROUGE-1, ROUGE-2 and ROUGE-L F-scores.

    Args:
        hyps: hypothesis text passed to ``Rouge.get_scores``.
        refer: reference text passed to ``Rouge.get_scores``.

    Returns:
        The average of the three F-scores of the first score entry, or 0.0
        when scoring the pair fails.
    """
    rouge = Rouge()
    try:
        score = rouge.get_scores(hyps, refer)[0]
        mean_score = np.mean(
            [score["rouge-1"]["f"], score["rouge-2"]["f"], score["rouge-l"]["f"]]
        )
    except Exception:
        # Scoring is deliberately best-effort: any scoring error counts as a
        # zero score. Catching Exception rather than using a bare ``except``
        # lets KeyboardInterrupt and SystemExit propagate, so a long labelling
        # run can still be interrupted.
        mean_score = 0.0
    return mean_score


def rouge_all(hyps, refer):
    """Return the first score entry produced by ``Rouge.get_scores``.

    Unlike ``rouge_eval`` nothing is averaged and no error is suppressed; the
    entry is handed back exactly as the scorer produced it.
    """
    all_scores = Rouge().get_scores(hyps, refer)
    return all_scores[0]


def eval_label(match_true, pred, true, total, match):
    """Compute accuracy, precision, recall and F-score from match counts.

    Args:
        match_true: count used as numerator of precision and recall.
        pred: denominator of precision.
        true: denominator of recall.
        total: denominator of accuracy (used as given, not converted).
        match: numerator of accuracy.

    ``match_true``, ``pred``, ``true`` and ``match`` must provide a
    ``.float()`` method (assumes torch tensors -- TODO confirm against callers).

    Returns:
        Tuple ``(accu, precision, recall, F)``. If a division raises
        ``ZeroDivisionError`` the error is logged, ``F`` is 0.0 and any metric
        that had not been computed yet is 0.0 as well.
    """
    match_true, pred, true, match = (
        match_true.float(),
        pred.float(),
        true.float(),
        match.float(),
    )
    # Bind every returned name up front: if an early division fails, the later
    # metrics would otherwise be undefined and the return statement would
    # raise UnboundLocalError instead of reporting the zero division.
    accu = precision = recall = F = 0.0
    try:
        accu = match / total
        precision = match_true / pred
        recall = match_true / true
        F = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        # NOTE(review): torch tensors yield inf/nan on division by zero rather
        # than raising, so this branch only triggers for inputs whose
        # ``.float()`` returns plain Python numbers.
        F = 0.0
        logger.error("[Error] float division by zero")
    return accu, precision, recall, F


def pyrouge_score(hyps, refer, remap=True):
    """Score one hypothesis against one reference with ROUGE-1.5.5 via pyrouge.

    The hypothesis is registered as pyrouge's *system* summary and the
    reference as its *model* (gold) summary, so "p" and "r" are precision and
    recall of the hypothesis with respect to the reference.

    Args:
        hyps: hypothesis text.
        refer: reference text.
        remap: when truthy, both texts are passed through ``clean`` first.

    Returns:
        Dict with keys "rouge-1", "rouge-2" and "rouge-l", each mapping to a
        dict with "p", "r" and "f" taken from pyrouge's parsed output.
    """
    # pyrouge must be installed, and ROUGE_PATH must contain RELEASE-1.5.5.
    from pyrouge import Rouge155
    import tempfile

    if remap:
        refer = clean(refer)
        hyps = clean(hyps)

    # A unique scratch directory (timestamp prefix kept for readability) so
    # that runs started within the same second cannot delete each other's files.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    work_dir = tempfile.mkdtemp(prefix=stamp + "_", dir=os.curdir)
    try:
        system_dir = os.path.join(work_dir, "system")
        gold_dir = os.path.join(work_dir, "gold")
        os.makedirs(system_dir)
        os.makedirs(gold_dir)

        with open(os.path.join(system_dir, "Model.0.txt"), "wb") as f:
            f.write(hyps.encode("utf-8"))
        with open(os.path.join(gold_dir, "Reference.A.0.txt"), "wb") as f:
            f.write(refer.encode("utf-8"))

        rouge = Rouge155(ROUGE_PATH + "RELEASE-1.5.5")
        rouge.system_dir = system_dir
        rouge.model_dir = gold_dir
        rouge.system_filename_pattern = r"Model.(\d+).txt"
        rouge.model_filename_pattern = "Reference.[A-Z].#ID#.txt"

        output = rouge.convert_and_evaluate(
            rouge_args="-e {}RELEASE-1.5.5/data -a -m -n 2 -d".format(ROUGE_PATH)
        )
        output_dict = rouge.output_to_dict(output)
    finally:
        # Always remove the scratch files, even when ROUGE itself fails.
        shutil.rmtree(work_dir, ignore_errors=True)

    return {
        "rouge-%s" % kind: {
            "p": output_dict["rouge_%s_precision" % kind],
            "r": output_dict["rouge_%s_recall" % kind],
            "f": output_dict["rouge_%s_f_score" % kind],
        }
        for kind in ("1", "2", "l")
    }


def pyrouge_score_all(hyps_list, refer_list, remap=True):
    """Score a corpus of hypothesis/reference pairs with ROUGE-1.5.5 via pyrouge.

    ``hyps_list[i]`` is evaluated against ``refer_list[i]``. Hypotheses are
    registered as pyrouge's *system* summaries and references as its *model*
    (gold) summaries, so "p" and "r" are precision and recall of the
    hypotheses with respect to the references.

    Args:
        hyps_list: sequence of hypothesis texts.
        refer_list: sequence of reference texts, same length as ``hyps_list``.
        remap: when truthy, every text is passed through ``clean`` first.

    Returns:
        Dict with keys "rouge-1", "rouge-2" and "rouge-l", each mapping to a
        dict with "p", "r" and "f" taken from pyrouge's parsed output.

    Raises:
        AssertionError: if the two lists differ in length.
    """
    from pyrouge import Rouge155
    import tempfile

    # Validate before touching the file system so a bad call leaves nothing behind.
    assert len(hyps_list) == len(refer_list)

    # A unique scratch directory (timestamp prefix kept for readability) so
    # that runs started within the same second cannot delete each other's files.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    work_dir = tempfile.mkdtemp(prefix=stamp + "_", dir=os.curdir)
    try:
        system_dir = os.path.join(work_dir, "system")
        gold_dir = os.path.join(work_dir, "gold")
        os.makedirs(system_dir)
        os.makedirs(gold_dir)

        for i, (hyps, refer) in enumerate(zip(hyps_list, refer_list)):
            if remap:
                hyps = clean(hyps)
                refer = clean(refer)
            with open(os.path.join(system_dir, "Model.%d.txt" % i), "wb") as f:
                f.write(hyps.encode("utf-8"))
            with open(os.path.join(gold_dir, "Reference.A.%d.txt" % i), "wb") as f:
                f.write(refer.encode("utf-8"))

        rouge = Rouge155(ROUGE_PATH + "RELEASE-1.5.5")
        rouge.system_dir = system_dir
        rouge.model_dir = gold_dir
        rouge.system_filename_pattern = r"Model.(\d+).txt"
        rouge.model_filename_pattern = "Reference.[A-Z].#ID#.txt"

        output = rouge.convert_and_evaluate(
            rouge_args="-e {}RELEASE-1.5.5/data -a -m -n 2 -d".format(ROUGE_PATH)
        )
        output_dict = rouge.output_to_dict(output)
    finally:
        # Always remove the scratch files, even when ROUGE itself fails.
        shutil.rmtree(work_dir, ignore_errors=True)

    return {
        "rouge-%s" % kind: {
            "p": output_dict["rouge_%s_precision" % kind],
            "r": output_dict["rouge_%s_recall" % kind],
            "f": output_dict["rouge_%s_f_score" % kind],
        }
        for kind in ("1", "2", "l")
    }


def pyrouge_score_all_multi(hyps_list, refer_list, remap=True):
    """Score hypotheses against multiple references each with ROUGE-1.5.5.

    ``hyps_list[i]`` is evaluated against every text in ``refer_list[i]``.
    Each reference of one example is stored under its own upper-case letter
    (A, B, C, ...), matching the ``[A-Z]`` part of the model filename pattern.

    Args:
        hyps_list: sequence of hypothesis texts.
        refer_list: sequence of sequences of reference texts, same length as
            ``hyps_list``.
        remap: when truthy, every text is passed through ``clean`` first.

    Returns:
        Dict with keys "rouge-1", "rouge-2" and "rouge-l", each mapping to a
        dict with "p", "r" and "f" taken from pyrouge's parsed output.

    Raises:
        AssertionError: if the two lists differ in length.
        IndexError: if an example has more than 26 references.
    """
    from pyrouge import Rouge155
    import string
    import tempfile

    # Validate before touching the file system so a bad call leaves nothing behind.
    assert len(hyps_list) == len(refer_list)

    # One letter per reference; the full alphabet is usable because the
    # filename pattern below accepts any of A-Z.
    refer_ids = string.ascii_uppercase

    # A unique scratch directory (timestamp prefix kept for readability) so
    # that runs started within the same second cannot delete each other's files.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    work_dir = tempfile.mkdtemp(prefix=stamp + "_", dir=os.curdir)
    try:
        system_dir = os.path.join(work_dir, "system")
        gold_dir = os.path.join(work_dir, "gold")
        os.makedirs(system_dir)
        os.makedirs(gold_dir)

        for i, (hyps, refers) in enumerate(zip(hyps_list, refer_list)):
            if remap:
                hyps = clean(hyps)
            with open(os.path.join(system_dir, "Model.%d.txt" % i), "wb") as f:
                f.write(hyps.encode("utf-8"))

            for j, refer in enumerate(refers):
                if remap:
                    refer = clean(refer)
                model_file = os.path.join(
                    gold_dir, "Reference.%s.%d.txt" % (refer_ids[j], i)
                )
                with open(model_file, "wb") as f:
                    f.write(refer.encode("utf-8"))

        rouge = Rouge155(ROUGE_PATH + "RELEASE-1.5.5")
        rouge.system_dir = system_dir
        rouge.model_dir = gold_dir
        rouge.system_filename_pattern = r"Model.(\d+).txt"
        rouge.model_filename_pattern = "Reference.[A-Z].#ID#.txt"

        output = rouge.convert_and_evaluate(
            rouge_args="-e {}RELEASE-1.5.5/data -a -m -n 2 -d".format(ROUGE_PATH)
        )
        output_dict = rouge.output_to_dict(output)
    finally:
        # Always remove the scratch files, even when ROUGE itself fails.
        shutil.rmtree(work_dir, ignore_errors=True)

    return {
        "rouge-%s" % kind: {
            "p": output_dict["rouge_%s_precision" % kind],
            "r": output_dict["rouge_%s_recall" % kind],
            "f": output_dict["rouge_%s_f_score" % kind],
        }
        for kind in ("1", "2", "l")
    }


def cal_label(article, abstract):
    """Greedily select sentence indices of ``article`` that maximise ROUGE.

    The best single sentence (by ``rouge_eval`` against ``abstract``) is
    picked first. Then, repeatedly, the unselected sentence whose addition
    gives the highest score is added, as long as that score is non-zero and
    not lower than the best score so far. Candidate sets are scored with their
    sentences joined by newlines in document order.

    Args:
        article: sequence of sentence strings.
        abstract: reference text.

    Returns:
        List of selected sentence indices in the order they were chosen;
        an empty list when ``article`` is empty.
    """
    hyps_list = article
    refer = abstract

    # Nothing to select from; np.argmax would raise on an empty score list.
    if len(hyps_list) == 0:
        return []

    scores = [rouge_eval(hyps, refer) for hyps in hyps_list]
    selected = [int(np.argmax(scores))]
    best_rouge = np.max(scores)

    while len(selected) < len(hyps_list):
        cur_max_rouge = 0.0
        cur_max_idx = -1
        for i in range(len(hyps_list)):
            if i in selected:
                continue
            candidate = sorted(selected + [i])
            hyps = "\n".join(hyps_list[idx] for idx in candidate)
            cur_rouge = rouge_eval(hyps, refer)
            if cur_rouge > cur_max_rouge:
                cur_max_rouge = cur_rouge
                cur_max_idx = i
        if cur_max_rouge != 0.0 and cur_max_rouge >= best_rouge:
            selected.append(cur_max_idx)
            best_rouge = cur_max_rouge
        else:
            # No remaining sentence keeps or improves the score.
            break

    return selected


def cal_label_limited3(article, abstract):
    """Greedily select at most three sentence indices of ``article``.

    The best single sentence (by ``rouge_eval`` against ``abstract``) is
    picked first. Then the unselected sentence whose addition gives the
    highest score is added, until three sentences are selected, the article is
    exhausted, or no remaining candidate scores above zero. Candidate sets are
    scored with their sentences joined by newlines in document order.

    Args:
        article: sequence of sentence strings.
        abstract: reference text.

    Returns:
        List of selected sentence indices (plain ``int``) in the order they
        were chosen; an empty list when ``article`` is empty.
    """
    hyps_list = article
    refer = abstract

    # Nothing to select from; np.argmax would raise on an empty score list.
    if len(hyps_list) == 0:
        return []

    # rouge_eval already returns 0.0 when scoring fails, so no extra guard is
    # needed around it here.
    scores = [rouge_eval(hyps, refer) for hyps in hyps_list]

    # int() keeps the result made of plain Python ints, as cal_label does.
    selected = [int(np.argmax(scores))]

    while len(selected) < len(hyps_list) and len(selected) < 3:
        cur_max_rouge = 0.0
        cur_max_idx = -1
        for i in range(len(hyps_list)):
            if i in selected:
                continue
            candidate = sorted(selected + [i])
            hyps = "\n".join(hyps_list[idx] for idx in candidate)
            cur_rouge = rouge_eval(hyps, refer)
            if cur_rouge > cur_max_rouge:
                cur_max_rouge = cur_rouge
                cur_max_idx = i
        if cur_max_idx == -1:
            # No candidate scored above zero. Appending the -1 placeholder
            # would later be read as "last sentence", so stop instead.
            break
        selected.append(cur_max_idx)

    return selected


import torch


def flip(x, dim):
    """Return a copy of ``x`` with its elements reversed along ``dim``.

    Args:
        x: input tensor.
        dim: dimension to reverse; a negative value counts from the last
            dimension.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # torch.flip works on x's own device. The previous hand-built index tensor
    # was moved with ``.cuda()``, i.e. to the current default GPU, which need
    # not be the GPU that holds x.
    return torch.flip(x, dims=[dim])


def get_attn_key_pad_mask(seq_k, seq_q):
    """Build a mask marking the zero-valued (padding) positions of the keys.

    The per-key mask is repeated along a new second dimension whose length is
    ``seq_q.size(1)``, so it lines up with a query-by-key attention matrix.
    """
    query_len = seq_q.shape[1]
    # True wherever the key entry equals 0.
    is_pad = torch.eq(seq_k, 0.0)
    # b x lk  ->  b x 1 x lk  ->  b x lq x lk
    return is_pad.unsqueeze(1).expand(-1, query_len, -1)


def get_non_pad_mask(seq):
    """Return a float mask that is 1.0 where ``seq`` is non-zero, else 0.0.

    ``seq`` must be two-dimensional; the result has one extra trailing
    dimension of size 1.
    """
    assert seq.dim() == 2

    non_zero = torch.ne(seq, 0.0)
    return non_zero.type(torch.float).unsqueeze(-1)
